{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "collapsed": true,
    "id": "8TOhzssTBZAb"
   },
   "source": [
    "# Finding Synonyms and Analogies\n",
    "\n",
    "This notebook is taken from a [PyTorch NLP tutorial](https://github.com/joosthub/pytorch-nlp-tutorial-ny2018/blob/master/day_1/0_Using_Pretrained_Embeddings.ipynb) source: [repository for the training tutorial as the 2018 O'Reilly AI Conference in NYC on April 29 and 30, 2018](https://github.com/joosthub/pytorch-nlp-tutorial-ny2018)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Colab SETUP\n",
    "#!pip install annoy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HhJ0hUC5BZAe"
   },
   "outputs": [],
   "source": [
    "from annoy import AnnoyIndex\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import notebook\n",
    "import os\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dhHjwm2JBZAl"
   },
   "source": [
    "Glove embeddings can be downloaded from [GloVe webpage](https://nlp.stanford.edu/projects/glove/).\n",
    "\n",
    "You need to uncomment the appropriate part in the following cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "f76c_EM1Ch3q"
   },
   "outputs": [],
   "source": [
    "## Colab SETUP\n",
    "#!mkdir data\n",
    "#%cd data\n",
    "#!wget http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip\n",
    "#!unzip glove.6B.zip\n",
    "#ROOT_DIR = 'content'\n",
    "## local SETUP download glove in ~/data/ with the commands wget http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip\n",
    "## and unzip glove.6B.zip\n",
    "#ROOT_DIR = Path.home()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NWUdh2BxBZAn"
   },
   "outputs": [],
   "source": [
    "data_path = os.path.join(ROOT_DIR,'data/')\n",
    "file = 'glove.6B.100d.txt'\n",
    "glove_filename=data_path+file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4GX0hecoBZAr"
   },
   "outputs": [],
   "source": [
    "def load_word_vectors(filename):\n",
    "    word_to_index = {}\n",
    "    word_vectors = []\n",
    "    \n",
    "    with open(filename) as fp:\n",
    "        for line in notebook.tqdm(fp.readlines(), leave=False):\n",
    "            line = line.split(\" \")\n",
    "            \n",
    "            word = line[0]\n",
    "            word_to_index[word] = len(word_to_index)\n",
    "            \n",
    "            vec = np.array([float(x) for x in line[1:]])\n",
    "            word_vectors.append(vec)\n",
    "            \n",
    "    return word_to_index, word_vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lDUMIr66BZAz"
   },
   "outputs": [],
   "source": [
    "word_to_index, word_vectors = load_word_vectors(glove_filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jvIOgxSRBZA8"
   },
   "outputs": [],
   "source": [
    "len(word_vectors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3iaYd15JBZBB"
   },
   "outputs": [],
   "source": [
    "word_vectors[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "IlkJErGvBZBG"
   },
   "outputs": [],
   "source": [
    "word_to_index['beautiful']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Y3K8WuWEBZBL"
   },
   "outputs": [],
   "source": [
    "class PreTrainedEmbeddings(object):\n",
    "    def __init__(self, glove_filename):\n",
    "        self.word_to_index, self.word_vectors = load_word_vectors(glove_filename)\n",
    "        self.word_vector_size = len(self.word_vectors[0])\n",
    "        \n",
    "        self.index_to_word = {v: k for k, v in self.word_to_index.items()}\n",
    "        self.index = AnnoyIndex(self.word_vector_size, metric='euclidean')\n",
    "        print('Building Index')\n",
    "        for _, i in notebook.tqdm(self.word_to_index.items(), leave=False):\n",
    "            self.index.add_item(i, self.word_vectors[i])\n",
    "        self.index.build(50)\n",
    "        print('Finished!')\n",
    "    \n",
    "    def get_embedding(self, word):\n",
    "        return self.word_vectors[self.word_to_index[word]]\n",
    "    \n",
    "    def closest(self, word, n=1):\n",
    "        vector = self.get_embedding(word)\n",
    "        nn_indices = self.index.get_nns_by_vector(vector, n)\n",
    "        return [self.index_to_word[neighbor] for neighbor in nn_indices]\n",
    "    \n",
    "    def closest_v(self, vector, n=1):\n",
    "        nn_indices = self.index.get_nns_by_vector(vector, n)\n",
    "        return [self.index_to_word[neighbor] for neighbor in nn_indices]\n",
    "    \n",
    "    def sim(self, w1, w2):\n",
    "        return np.dot(self.get_embedding(w1), self.get_embedding(w2))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bZzQW7pyBZBP"
   },
   "outputs": [],
   "source": [
    "glove = PreTrainedEmbeddings(glove_filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Stpp4FNaBZBT"
   },
   "outputs": [],
   "source": [
    "glove.closest('apple', n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ARYYKOwcBZBa"
   },
   "outputs": [],
   "source": [
    "glove.closest('chip', n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "e0-ywu08BZBe"
   },
   "outputs": [],
   "source": [
    "glove.closest('baby', n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "AtS8hOuPBZBm"
   },
   "outputs": [],
   "source": [
    "glove.closest('beautiful', n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jhPAEA5lBZBs"
   },
   "outputs": [],
   "source": [
    "def SAT_analogy(w1, w2, w3):\n",
    "    '''\n",
    "    Solves problems of the type:\n",
    "    w1 : w2 :: w3 : __\n",
    "    '''\n",
    "    closest_words = []\n",
    "    try:\n",
    "        w1v = glove.get_embedding(w1)\n",
    "        w2v = glove.get_embedding(w2)\n",
    "        w3v = glove.get_embedding(w3)\n",
    "        w4v = w3v + (w2v - w1v)\n",
    "        closest_words = glove.closest_v(w4v, n=5)\n",
    "        closest_words = [w for w in closest_words if w not in [w1, w2, w3]]\n",
    "    except:\n",
    "        pass\n",
    "    if len(closest_words) == 0:\n",
    "        print(':-(')\n",
    "    else:\n",
    "        print('{} : {} :: {} : {}'.format(w1, w2, w3, closest_words[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rlxpUPIhBZBw"
   },
   "outputs": [],
   "source": [
    "SAT_analogy('man', 'he', 'woman')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "07UCvv3fBZB4"
   },
   "outputs": [],
   "source": [
    "SAT_analogy('fly', 'plane', 'sail')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gOHAs9gIBZB9"
   },
   "outputs": [],
   "source": [
    "SAT_analogy('beijing', 'china', 'tokyo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "VdpkkLgNBZCC"
   },
   "outputs": [],
   "source": [
    "SAT_analogy('man', 'woman', 'son')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "M019lyRbBZCH"
   },
   "outputs": [],
   "source": [
    "SAT_analogy('man', 'doctor', 'woman')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Dmuw3B3zBZCL"
   },
   "outputs": [],
   "source": [
    "SAT_analogy('woman', 'leader', 'man')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qeWK5ASVBZCP"
   },
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "include_colab_link": false,
   "name": "08_Playing_with_word_embedding.ipynb",
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "dldiy",
   "language": "python",
   "name": "dldiy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
